#!/usr/bin/env python
import os
import argparse

# fmt: off
def get_cmd(
    task, sub_task, model_tag, model_path, gpu, data_num, bs, lr, source_length,target_length,patience,epoch,warmup,
    model_dir,summary_dir,res_fn,max_steps=None,save_steps=None,log_steps=None,parse_as_tree=False,grad_acc_steps=1,
):
    """Build the ``bash`` command line that launches one fine-tuning run.

    Args:
        task, sub_task, model_tag: Strings forwarded verbatim to the script.
        model_path: Model location/identifier, forwarded as a string.
        gpu: Accepted for interface compatibility; it is not part of the
            generated command.
        data_num, bs, lr, source_length, target_length, patience, epoch,
            warmup: Integers forwarded to the script (rendered with ``%d``).
        model_dir, summary_dir, res_fn: Output locations, forwarded as strings.
        max_steps, save_steps, log_steps: When ``max_steps`` is not ``None``
            the command ends with these three integers instead of
            ``grad_acc_steps``.
        parse_as_tree: Selects ``exp_with_args_trees.sh`` instead of
            ``exp_with_args.sh``.
        grad_acc_steps: Appended as the last argument when ``max_steps`` is
            ``None``.

    Returns:
        The command string. In the ``max_steps is None`` case the command is
        also printed to stdout.
    """
    script = "exp_with_args_trees.sh" if parse_as_tree else "exp_with_args.sh"

    # The leading arguments are identical for both launch modes. model_path is
    # a string, so it is rendered with %s in both of them.
    common = "bash %s %s %s %s %s %d %d %d %d %d %d %d %d %s %s %s" % (
        script, task, sub_task, model_tag, model_path, data_num, bs, lr, source_length,
        target_length, patience, epoch, warmup, model_dir, summary_dir, res_fn,
    )

    if max_steps is None:
        cmd_str = "%s %d" % (common, grad_acc_steps)

        print(cmd_str)
    else:
        # gets invoked with multi-task
        cmd_str = "%s %d %d %d" % (common, max_steps, save_steps, log_steps)
    return cmd_str


def get_args_by_task_model_string(task, sub_task, model_tag):
    """Return the fine-tuning hyper-parameters used for plain (non-tree) inputs.

    Args:
        task: Task name, e.g. ``"translate"`` or ``"summarize"``.
        sub_task: Sub-task name. Only ``"refine"`` uses it (``"small"`` or
            ``"medium"``) to pick sequence lengths and batch size.
        model_tag: Model identifier. The CodeT5 small/large variants are
            matched by substring; ``"codebert"``/``"roberta"`` by equality.

    Returns:
        Tuple ``(bs, lr, src_len, trg_len, patience, epoch, grad_acc_steps)``.

    Raises:
        ValueError: If ``task`` has no settings here, or ``task`` is
            ``"refine"`` and ``sub_task`` is neither ``"small"`` nor
            ``"medium"``.
    """
    # task -> (src_len, trg_len, epoch, patience). The dataset statistics that
    # accompany each choice are kept beside the entry.
    settings_by_task = {
        # java-cs: Read 10300 examples, avg src len: 13, avg trg len: 15, max src len: 136, max trg len: 118
        # [TOKENIZE] avg src len: 45, avg trg len: 56, max src len: 391, max trg len: 404
        # Trees stats
        # java-cs: Read 10300 examples, avg src len: 87, avg trg len: 135, max src len: 833, max trg len: 951
        # [TOKENIZE] avg src len: 98, avg trg len: 155, max src len: 963, max trg len: 1156
        # (an earlier setting was src_len 320 / trg_len 256)
        "translate": (512, 512, 100, 5),
        # ruby: Read 24927 examples, avg src len: 66, avg trg len: 12, max src len: 501, max trg len: 146
        # [TOKENIZE] avg src len: 100, avg trg len: 13, max src len: 1250, max trg len: 161
        # Python: Read 251820 examples, avg src len: 100, avg trg len: 11, max src len: 512, max trg len: 222
        # [TOKENIZE] avg src len: 142, avg trg len: 12, max src len: 2016, max trg len: 245
        # Javascript: Read 58025 examples, avg src len: 114, avg trg len: 11, max src len: 512, max trg len: 165
        # [TOKENIZE] avg src len: 136, avg trg len: 12, max src len: 3016, max trg len: 177
        "summarize": (256, 128, 15, 2),
        # Read 100000 examples, avg src len: 71, avg trg len: 26, max src len: 567, max trg len: 140
        # [TOKENIZE] avg src len: 213, avg trg len: 33, max src len: 2246, max trg len: 264
        "concode": (320, 150, 30, 3),
        # Read 21854 examples, avg src len: 187, avg trg len: 1, max src len: 12195, max trg len: 1
        # [TOKENIZE] avg src len: 597, avg trg len: 1, max src len: 41447, max trg len: 1
        "defect": (1024, 3, 10, 2),
        # Read 901028 examples, avg src len: 120, avg trg len: 123, max src len: 5270, max trg len: 5270
        # [TOKENIZE] avg src len: 318, avg trg len: 323, max src len: 15111, max trg len: 15111
        "clone": (400, 400, 1, 2),
        # Read 19209 examples, avg src len: 50, avg trg len: 31, max src len: 408, max trg len: 251
        # [TOKENIZE] avg src len: 66, avg trg len: 64, max src len: 487, max trg len: 463
        "mathqa": (512, 512, 50, 5),
        # Read 50000 examples, avg src len: 254, avg trg len: 259, max src len: 2283, max trg len: 2156
        # [TOKENIZE] avg src len: 269, avg trg len: 273, max src len: 2515, max trg len: 2216
        "fixeval": (512, 512, 20, 2),
        # Read 374 examples, avg src len: 14, avg trg len: 23, max src len: 47, max trg len: 129
        # [TOKENIZE] avg src len: 16, avg trg len: 71, max src len: 49, max trg len: 304
        "mbpp": (64, 310, 50, 5),
        # Read 1879 examples, avg src len: 10, avg trg len: 4, max src len: 29, max trg len: 19
        # [TOKENIZE] avg src len: 15, avg trg len: 15, max src len: 62, max trg len: 69
        "conala": (32, 128, 50, 5),
        # Read 3391 examples, avg src len: 111, avg trg len: 88, max src len: 427, max trg len: 404
        # [TOKENIZE] avg src len: 119, avg trg len: 118, max src len: 542, max trg len: 563
        "avatar": (512, 512, 50, 5),
    }
    # "refine" is the only task whose lengths depend on the sub-task.
    # small: Read 46680 examples, avg src len: 31, avg trg len: 28, max src len: 50, max trg len: 50
    # [TOKENIZE] avg src len: 50, avg trg len: 45, max src len: 129, max trg len: 121
    # medium:  Read 52364 examples, avg src len: 74, avg trg len: 73, max src len: 100, max trg len: 100
    # [TOKENIZE] avg src len: 117, avg trg len: 114, max src len: 238, max trg len: 238
    refine_lengths = {"small": (130, 120), "medium": (240, 240)}

    if task == "refine":
        if sub_task not in refine_lengths:
            raise ValueError(
                "unsupported sub_task %r for task 'refine' (expected 'small' or 'medium')"
                % (sub_task,)
            )
        src_len, trg_len = refine_lengths[sub_task]
        epoch, patience = 50, 5
    elif task in settings_by_task:
        src_len, trg_len, epoch, patience = settings_by_task[task]
    else:
        # Fail with a clear message instead of an UnboundLocalError at return.
        raise ValueError("unsupported task: %r" % (task,))

    grad_acc_steps = 1

    if "codet5_small" in model_tag:
        if task in ("summarize", "translate") or (task == "refine" and sub_task == "small"):
            bs = 64
        elif task == "clone":
            bs = 25
        else:
            bs = 32
    elif "codet5_large" in model_tag:
        bs = 8
    else:
        # Per-task batch sizes for every other model; unlisted tasks use 32.
        # (translate previously used 25.)
        bs_by_task = {
            "translate": 12,
            "summarize": 48,
            "mathqa": 12,
            "fixeval": 12,
            "mbpp": 8,
            "avatar": 8,
        }
        if task == "clone":
            bs = 16 if model_tag in ("codebert", "roberta") else 10
        elif task == "defect":
            # The reduced batch size is tied to the 1024-token source length.
            bs = 8 if src_len == 1024 else 32
        else:
            bs = bs_by_task.get(task, 32)
        if task == "fixeval":
            grad_acc_steps = 3

    lr = {"concode": 10, "defect": 2}.get(task, 5)
    return bs, lr, src_len, trg_len, patience, epoch, grad_acc_steps


def get_args_by_task_model_trees(task, sub_task, model_tag):
    """Return the fine-tuning hyper-parameters used for tree-parsed inputs.

    Args:
        task: Task name, e.g. ``"translate"`` or ``"summarize"``.
        sub_task: Sub-task name. Only ``"refine"`` uses it (``"small"`` or
            ``"medium"``) to pick sequence lengths and batch size.
        model_tag: Model identifier. The CodeT5 small/large variants are
            matched by substring; ``"codebert"``/``"roberta"`` by equality.

    Returns:
        Tuple ``(bs, lr, src_len, trg_len, patience, epoch, grad_acc_steps)``.

    Raises:
        ValueError: If ``task`` has no settings here, or ``task`` is
            ``"refine"`` and ``sub_task`` is neither ``"small"`` nor
            ``"medium"``.
    """
    # task -> (src_len, trg_len, epoch, patience). The dataset statistics that
    # accompany each choice are kept beside the entry.
    settings_by_task = {
        # Trees
        # java-cs: Read 10300 examples, avg src len: 87, avg trg len: 135, max src len: 833, max trg len: 951
        # [TOKENIZE] avg src len: 98, avg trg len: 155, max src len: 963, max trg len: 1156
        "translate": (512, 512, 100, 5),
        # ruby: Read 24927 examples, avg src len: 66, avg trg len: 12, max src len: 501, max trg len: 146
        # [TOKENIZE] avg src len: 100, avg trg len: 13, max src len: 1250, max trg len: 161
        # Python: Read 251820 examples, avg src len: 100, avg trg len: 11, max src len: 512, max trg len: 222
        # [TOKENIZE] avg src len: 142, avg trg len: 12, max src len: 2016, max trg len: 245
        # Javascript: Read 58025 examples, avg src len: 114, avg trg len: 11, max src len: 512, max trg len: 165
        # [TOKENIZE] avg src len: 136, avg trg len: 12, max src len: 3016, max trg len: 177
        "summarize": (256, 128, 15, 2),
        # Trees
        # Read 99996 examples, avg src len: 71, avg trg len: 67, max src len: 567, max trg len: 369
        # [TOKENIZE] avg src len: 213, avg trg len: 76, max src len: 2246, max trg len: 407
        "concode": (320, 410, 30, 3),
        # Read 21854 examples, avg src len: 187, avg trg len: 1, max src len: 12195, max trg len: 1
        # [TOKENIZE] avg src len: 597, avg trg len: 1, max src len: 41447, max trg len: 1
        # Tree stats
        # Read 21854 examples, avg src len: 535, avg trg len: 1, max src len: 33680, max trg len: 1
        # [TOKENIZE] avg src len: 1237, avg trg len: 1, max src len: 78554, max trg len: 1
        "defect": (1024, 3, 10, 2),
        # Read 901028 examples, avg src len: 120, avg trg len: 123, max src len: 5270, max trg len: 5270
        # [TOKENIZE] avg src len: 318, avg trg len: 323, max src len: 15111, max trg len: 15111
        "clone": (400, 400, 1, 2),
        # Read 19209 examples, avg src len: 50, avg trg len: 125, max src len: 408, max trg len: 908
        # [TOKENIZE] avg src len: 66, avg trg len: 149, max src len: 487, max trg len: 1070
        "mathqa": (512, 512, 50, 5),
        # Read 49881 examples, avg src len: 31, avg trg len: 21, max src len: 2811, max trg len: 2067
        # [TOKENIZE] avg src len: 44, avg trg len: 33, max src len: 3119, max trg len: 3039
        "fixeval": (512, 512, 20, 2),
        # Read 374 examples, avg src len: 14, avg trg len: 145, max src len: 47, max trg len: 646
        # [TOKENIZE] avg src len: 16, avg trg len: 156, max src len: 49, max trg len: 652
        "mbpp": (64, 512, 50, 5),
        # Read 1879 examples, avg src len: 10, avg trg len: 41, max src len: 29, max trg len: 151
        # [TOKENIZE] avg src len: 15, avg trg len: 47, max src len: 62, max trg len: 167
        "conala": (32, 170, 50, 5),
        # Read 3390 examples, avg src len: 291, avg trg len: 272, max src len: 1069, max trg len: 1170
        # [TOKENIZE] avg src len: 306, avg trg len: 281, max src len: 1319, max trg len: 1261
        "avatar": (512, 512, 50, 5),
    }
    # "refine" is the only task whose lengths depend on the sub-task.
    # Trees
    # small: Read 46678 examples, avg src len: 84, avg trg len: 76, max src len: 313, max trg len: 291
    # [TOKENIZE] avg src len: 102, avg trg len: 92, max src len: 331, max trg len: 309
    # Trees:
    # medium: Read 52359 examples, avg src len: 196, avg trg len: 191, max src len: 598, max trg len: 598
    # [TOKENIZE] avg src len: 238, avg trg len: 232, max src len: 624, max trg len: 624
    refine_lengths = {"small": (330, 310), "medium": (512, 512)}

    if task == "refine":
        if sub_task not in refine_lengths:
            raise ValueError(
                "unsupported sub_task %r for task 'refine' (expected 'small' or 'medium')"
                % (sub_task,)
            )
        src_len, trg_len = refine_lengths[sub_task]
        epoch, patience = 50, 5
    elif task in settings_by_task:
        src_len, trg_len, epoch, patience = settings_by_task[task]
    else:
        # Fail with a clear message instead of an UnboundLocalError at return.
        raise ValueError("unsupported task: %r" % (task,))

    grad_acc_steps = 1

    if "codet5_small" in model_tag:
        if task in ("summarize", "translate") or (task == "refine" and sub_task == "small"):
            bs = 64
        elif task == "clone":
            bs = 25
        else:
            bs = 32
    elif "codet5_large" in model_tag:
        bs = 8
    else:
        # Per-task batch sizes for every other model; unlisted tasks use 32.
        bs_by_task = {
            "translate": 12,
            "summarize": 48,
            "concode": 16,
            "mathqa": 12,
            "fixeval": 12,
            "mbpp": 8,
            "avatar": 8,
        }
        if task == "refine":
            bs = 16 if sub_task == "small" else 12
        elif task == "clone":
            bs = 16 if model_tag in ("codebert", "roberta") else 10
        elif task == "defect":
            # The reduced batch size is tied to the 1024-token source length.
            bs = 8 if src_len == 1024 else 32
        else:
            bs = bs_by_task.get(task, 32)
        if task == "fixeval":
            grad_acc_steps = 3

    lr = {"concode": 10, "defect": 2}.get(task, 5)
    return bs, lr, src_len, trg_len, patience, epoch, grad_acc_steps


def get_args_by_task_model(task, sub_task, model_tag, parse_as_tree):
    """Pick the hyper-parameter lookup matching the input representation.

    A truthy ``parse_as_tree`` delegates to ``get_args_by_task_model_trees``;
    otherwise ``get_args_by_task_model_string`` is used. The chosen function's
    result is returned unchanged.
    """
    lookup = (
        get_args_by_task_model_trees
        if parse_as_tree
        else get_args_by_task_model_string
    )
    return lookup(task, sub_task, model_tag)


def run_one_exp(args):
    """Launch a single-task fine-tuning run described by the parsed CLI ``args``.

    Looks up the hyper-parameters for ``args.task`` / ``args.sub_task`` /
    ``args.model_tag``, builds the shell command with ``get_cmd`` (warmup is
    fixed at 1000), prints it and executes it through ``os.system``.
    """
    hyper_params = get_args_by_task_model(
        args.task, args.sub_task, args.model_tag, args.parse_as_tree
    )
    (
        batch_size,
        learning_rate,
        max_source,
        max_target,
        stop_patience,
        num_epochs,
        accumulation,
    ) = hyper_params

    print("============================Start Running==========================")

    # Results are written to one file per (task, model) pair inside res_dir.
    result_file = "{}/{}_{}.txt".format(args.res_dir, args.task, args.model_tag)
    command = get_cmd(
        task=args.task,
        sub_task=args.sub_task,
        model_tag=args.model_tag,
        model_path=args.model_path,
        gpu=args.gpu,
        data_num=args.data_num,
        bs=batch_size,
        lr=learning_rate,
        source_length=max_source,
        target_length=max_target,
        patience=stop_patience,
        epoch=num_epochs,
        warmup=1000,
        model_dir=args.model_dir,
        summary_dir=args.summary_dir,
        res_fn=result_file,
        parse_as_tree=args.parse_as_tree,
        grad_acc_steps=accumulation,
    )
    print("%s\n" % command)
    os.system(command)


def run_multi_task_exp(args):
    """Launch the multi-task fine-tuning run described by the parsed CLI ``args``.

    Picks a step-based schedule from ``args.model_tag``, shortens it when
    ``args.data_num`` is not ``-1``, builds the shell command with ``get_cmd``,
    prints it and executes it through ``os.system``. Length, patience and epoch
    arguments are passed as ``-1`` because the schedule is step-based.
    """
    # Total train data num = 1149722 (for all five tasks)
    if "codet5_small" in args.model_tag:
        bs, lr, max_steps, save_steps, log_steps = 60, 5, 600000, 20000, 100
    else:
        bs, lr, max_steps, save_steps, log_steps = 25, 5, 800000, 20000, 100

    if args.data_num != -1:
        # A data_num other than -1 switches to a much shorter schedule.
        max_steps, save_steps, log_steps = 1000, 200, 50
    print("============================Start Running==========================")
    cmd_str = get_cmd(
        task="multi_task",
        sub_task="none",
        model_tag=args.model_tag,
        # model_path is a required get_cmd parameter; omitting it raised
        # TypeError before any command could be built.
        model_path=args.model_path,
        gpu=args.gpu,
        data_num=args.data_num,
        bs=bs,
        lr=lr,
        source_length=-1,
        target_length=-1,
        patience=-1,
        epoch=-1,
        warmup=1000,
        model_dir=args.model_dir,
        summary_dir=args.summary_dir,
        res_fn="{}/multi_task_{}.txt".format(args.res_dir, args.model_tag),
        max_steps=max_steps,
        save_steps=save_steps,
        log_steps=log_steps,
    )
    print("%s\n" % cmd_str)
    os.system(cmd_str)


def get_sub_tasks(task):
    """Return the list of sub-task names accepted for ``task``.

    Tasks without real sub-tasks map to ``["none"]``. A fresh list is returned
    on every call, so callers may mutate it.

    Raises:
        ValueError: If ``task`` is not a known task name.
    """
    sub_tasks_by_task = {
        "summarize": ["ruby", "javascript", "go", "python", "java", "php"],
        "translate": ["java-cs", "cs-java"],
        "refine": ["small", "medium"],
        "fixeval": ["java", "python"],
        "avatar": ["java-python", "python-java"],
    }
    for single in ("concode", "defect", "clone", "multi_task", "mathqa", "mbpp", "conala"):
        sub_tasks_by_task[single] = ["none"]

    try:
        return sub_tasks_by_task[task]
    except (KeyError, TypeError):
        # Unknown (or unhashable) task: report it explicitly rather than
        # failing with an UnboundLocalError.
        raise ValueError("unsupported task: %r" % (task,))


if __name__ == "__main__":
    # Command-line entry point: parse options, validate the task/sub-task
    # pair, make sure the results directory exists, then launch the run.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_tag",
        type=str,
        default="codet5_base",
        choices=[
            "roberta",
            "codebert",
            "bart_base",
            "codet5_small",
            "codet5_base",
            "codet5_large",
            "codet5_custom",
        ],
    )

    parser.add_argument("--model_path", type=str, required=False, default="placeholder")
    parser.add_argument("--parse_as_tree", action="store_true")

    parser.add_argument(
        "--task",
        type=str,
        default="summarize",
        choices=[
            "summarize",
            "concode",
            "translate",
            "refine",
            "defect",
            "clone",
            "multi_task",
            "mathqa",
            "fixeval",
            "mbpp",
            "conala",
            "avatar"
        ],
    )
    parser.add_argument("--sub_task", type=str, default="ruby")
    parser.add_argument(
        "--res_dir",
        type=str,
        default="results",
        help="directory to save fine-tuning results",
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default="saved_models",
        help="directory to save fine-tuned models",
    )
    parser.add_argument(
        "--summary_dir",
        type=str,
        default="tensorboard",
        help="directory to save tensorboard summary",
    )
    parser.add_argument(
        "--data_num",
        type=int,
        default=-1,
        help="number of data instances to use, -1 for full data",
    )
    parser.add_argument(
        "--gpu", type=int, default=0, help="index of the gpu to use in a cluster"
    )
    args = parser.parse_args()

    # Validate before touching the filesystem. parser.error is used instead of
    # assert because asserts are stripped under ``python -O`` and would let an
    # invalid task/sub-task pair through.
    valid_sub_tasks = get_sub_tasks(args.task)
    if args.sub_task not in valid_sub_tasks:
        parser.error(
            "--sub_task %r is not valid for --task %s (expected one of: %s)"
            % (args.sub_task, args.task, ", ".join(valid_sub_tasks))
        )

    # exist_ok avoids the check-then-create race when several runs start at once.
    os.makedirs(args.res_dir, exist_ok=True)

    if args.task != "multi_task":
        run_one_exp(args)
    else:
        run_multi_task_exp(args)
